# %%[markdown]
# Simulation results visualization


# %%
import glob
import pickle
from pathlib import Path

import astropy.units as u
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from astropy.constants import R_earth
from pyquaternion import Quaternion

from adcs import quaternion_to_euler_angle
from ballistic import phasing_angle
from helpers import interpolate_positions
from mission_parameters import (
    ALPHA_GOAL,
    ANGLE_ACCURACY,
    MISSION_PREFIX,
    OMEGA_ACCURACY,
    RESULTS_FOLDER,
    VALVE_OPENING_DURATION,
)

# %%
# FIXME choose required
PLOTS_EXTENSION = "pdf"

PLOTS_FOLDER = f"../plots/{MISSION_PREFIX}"
Path(PLOTS_FOLDER).mkdir(parents=True, exist_ok=True)

# PLOTS_FOLDER = f"../articles/engjournal/images/plots/{MISSION_PREFIX}"
# PLOTS_EXTENSION = "png"


# %%
all_mission_files = glob.glob(f"{RESULTS_FOLDER}/*{MISSION_PREFIX}*")
# last_filename = "results/SixU-res-2021-01-24-21:56:13"
last_filename = max(all_mission_files)  # newest one with greater date
with open(last_filename, "rb") as res_input_file:
    if ALPHA_GOAL is not None:
        (
            sat_solutions,
            other_parameters,
            free_flight_phasing_sol,
            free_flight_single_sol,
        ) = pickle.load(res_input_file)
    else:
        sat_solutions, other_parameters = pickle.load(res_input_file)


# %%
def my_plot(
    times,
    values,
    title="",
    ylabel="",
    ylabels=[""],
    filename="",
    y_fmt=None,
    xlim=None,
    subfolder="",
):
    fig = plt.figure()
    ax = fig.add_subplot(111)

    if y_fmt is not None:
        ax.yaxis.set_major_formatter(y_fmt)
    ax.grid(True)

    # FIXME если есть подрисуночные надписи, то загологовок надо удалить
    # ax.set_title(title)
    ax.set_title("")

    for i in range(len(values)):
        t = times[i]
        ax.plot(t, values[i], label=ylabels[i])

    trange = ax.axes.get_xlim()
    if trange[1] - trange[0] > 86400:
        ax.set_xlabel("Время, сутки")
        days_fmt = matplotlib.ticker.FuncFormatter(lambda m, _: f"{m/86400:.0f}")
        ax.xaxis.set_major_formatter(days_fmt)
        ax.xaxis.set_major_locator(matplotlib.ticker.MultipleLocator(86400))
    else:
        ax.set_xlabel("Время, с")
        plt.locator_params(axis="x", nbins=5)

    if len(values) > 1:
        plt.legend()

    ax.set_ylabel(ylabel)

    if xlim is not None:
        plt.xlim(xlim)

    plt.show()
    fig.savefig(
        f"{PLOTS_FOLDER}{subfolder}/{filename}.{PLOTS_EXTENSION}",
        dpi=600,
        bbox_inches="tight",
    )


# %%
def get_mode_start_time(sol, mode, following_by_mode=None):
    t = sol[0]
    modes = sol[2]
    mode_start_time = None

    for i in range(1, len(t)):
        if following_by_mode is None:
            if modes[i] == mode:
                mode_start_time = t[i]
                break
        else:
            if modes[i] == mode and modes[i - 1] == following_by_mode:
                mode_start_time = t[i]
                break

    return mode_start_time


# %%
def graphs_plotting():
    """Plotting all graphs"""
    phase_number = 0
    sat_number = 0

    sol = sat_solutions[phase_number][sat_number]
    # ignore first point (zero time) which solver adds automatically to any
    # initial time
    # FIXME filter this in mission_simulation
    # sol[0] = sol[0][1:]
    # sol[1] = [vector[1:] for vector in sol[1]]
    # sol[2] = sol[2][1:]

    t = sol[0]
    internal_energy = sol[1][0]
    m_fuel = sol[1][1]
    x, y, z = sol[1][2:5]
    vx, vy, vz = sol[1][5:8]
    omega_1, omega_2, omega_3 = sol[1][8:11]
    q_0, q_1, q_2, q_3 = sol[1][11:15]
    wheels_rate = sol[1][15:19]
    electric_power = sol[1][19]

    (
        p_chamber,
        t_chamber,
        m_steam,
        m_liq,
        thrust,
        imp_sp,
        mass_flow,
        sun_b,
        o1_b,
        q_cmd,
        q_err,
        torque_cmd,
        rw_acc,
    ) = other_parameters[phase_number][sat_number]

    r_norm = np.linalg.norm([x, y, z], axis=0)
    altitude = r_norm - R_earth.to(u.m).value
    v_norm = np.linalg.norm([vx, vy, vz], axis=0)

    my_plot(
        [t],
        [internal_energy / 1000],
        title="Накопленная тепловая энергия",
        ylabel="Накопленная тепловая энергия, кДж",
        filename="thermal_energy",
    )

    my_plot(
        [t],
        [m_fuel],
        title="Масса рабочего тела",
        ylabel="Масса рабочего тела, кг",
        filename="fuel_mass",
    )

    my_plot(
        [t],
        [p_chamber],
        title="Давление в камере",
        ylabel="Давление в камере, атм.",
        filename="chamber_pressure",
    )

    my_plot(
        [t],
        [t_chamber],
        title="Температура в камере",
        ylabel="Температура в камере, К",
        filename="chamber_temperature",
    )

    my_plot(
        [t],
        [thrust],
        title="Тяга двигателя",
        ylabel="Тяга, Н",
        filename="thrust",
    )

    my_plot(
        [t],
        [imp_sp],
        title="Удельный импульс",
        ylabel="Удельный импульс, м/с",
        filename="specific_impulse",
    )

    # FIXME убрать скачок по высоте в последней точке
    kms_fmt = matplotlib.ticker.FuncFormatter(lambda m, _: f"{m*1e-3:.0f}")
    my_plot(
        [t[:-1]],
        [altitude[:-1]],
        title="Высота полета",
        ylabel="Высота полета, км",
        y_fmt=kms_fmt,
        filename="altitude",
    )

    # v_norm_ini = np.linalg.norm(satellite_initial_orbit_parameters()[3:6])
    # v_diff = v_norm - v_norm_ini
    # TODO поставить здесь другой порог по оси y
    # TODO убрать скачок по высоте в последней точке
    my_plot(
        [t[:-1]],
        # v_diff,
        [v_norm[:-1]],
        title="Скорость движения",
        ylabel="Скорость, м/c",
        filename="velocity",
    )

    my_plot(
        [t],
        [electric_power / 3600],
        title="Заряд аккумуляторной батареи",
        ylabel="Заряд, Вт·ч",
        filename="power",
    )

    yvariables = []
    ylabels = []
    times = []
    for q_el in range(4):
        yvariables.append([qi.elements[q_el] for qi in q_cmd])
        ylabels.append(f"$q_{{cmd{q_el}}}$")
        times.append(t)
    my_plot(
        times,
        yvariables,
        ylabels=ylabels,
        ylabel="Компоненты командного кватерниона",
        title="Командный кватернион",
        filename="command_quaternion",
    )

    ylabels = []
    times = []
    for i in range(4):
        ylabels.append(f"маховик {i}")
        times.append(t)
    my_plot(
        times,
        wheels_rate,
        ylabels=ylabels,
        ylabel="Угловые скорости маховиков, $c^{-1}$",
        title="Угловые скорости маховиков",
        filename="wheels_rate",
    )

    # ##################################
    # Локальные графики
    # ##################################

    first_thrusting_aft_orb_time = get_mode_start_time(
        sol, "thrusting", following_by_mode="orbital_orientation"
    )
    first_orbital_orientation_time = get_mode_start_time(sol, "orbital_orientation")

    x_left = first_thrusting_aft_orb_time - VALVE_OPENING_DURATION
    x_right = first_thrusting_aft_orb_time + 2 * VALVE_OPENING_DURATION

    my_plot(
        [t],
        [thrust],
        title="Тяга двигателя",
        ylabel="Тяга, Н",
        xlim=[x_left, x_right],
        filename="thrust_scaled_thrusting",
    )

    my_plot(
        [t],
        [imp_sp],
        title="Удельный импульс",
        ylabel="Удельный импульс, м/с",
        xlim=[x_left, x_right],
        filename="specific_impulse_scaled_thrusting",
    )

    x_left = first_orbital_orientation_time - VALVE_OPENING_DURATION
    x_right = first_thrusting_aft_orb_time + 2 * VALVE_OPENING_DURATION

    ylabels = []
    times = []
    for i in range(3):
        ylabels.append(f"$\\omega_{i}$")
        times.append(t)
    my_plot(
        times,
        [omega_1, omega_2, omega_3],
        ylabels=ylabels,
        ylabel="Угловая скорость КА, $c^{-1}$",
        title="Угловая скорость КА",
        xlim=[x_left, x_right],
        filename="omega_scale_orbital_orientation",
    )

    ylabels = []
    times = []
    for i in range(4):
        ylabels.append(f"$q_{i}$")
        times.append(t)
    my_plot(
        times,
        [q_0, q_1, q_2, q_3],
        ylabels=ylabels,
        ylabel="Кватернион ориентации КА",
        title="Кватернион ориентации КА",
        xlim=[x_left, x_right],
        filename="quaternion_scale_orbital_orientation",
    )

    q_bi_mat = np.transpose(np.array([q_0, q_1, q_2, q_3]))
    q_cmd_mat = np.transpose(np.array(q_cmd))
    q_bi_list = [Quaternion(q[0], q[1], q[2], q[3]) for q in q_bi_mat]
    q_cmd_list = [Quaternion(q[0], q[1], q[2], q[3]) for q in q_cmd_mat]
    q_err_list = [q_bi_list[i] * q_cmd_list[i].inverse for i in range(len(q_bi_list))]
    omega_norm = np.linalg.norm([omega_1, omega_2, omega_3], axis=0)
    q_err_angle_abs = np.array([abs(q_err.angle) for q_err in q_err_list])
    orientation_event_fun_value = (
        omega_norm + q_err_angle_abs - ANGLE_ACCURACY - OMEGA_ACCURACY
    )
    my_plot(
        [t],
        # v_diff,
        [orientation_event_fun_value],
        title="orientation_event_fun_value",
        xlim=[x_left - 50, x_right + 50],
        ylabel="",
        filename="orientation_event_fun_value",
    )

    Euler_angle_0, Euler_angle_1, Euler_angle_2 = quaternion_to_euler_angle(
        q_0, q_1, q_2, q_3
    )
    my_plot(
        [t] * 3,
        list(map(np.rad2deg, [Euler_angle_0, Euler_angle_1, Euler_angle_2])),
        ylabels=["Euler_angle_0", "Euler_angle_1", "Euler_angle_2"],
        ylabel="Углы Эйлера, градусы",
        title="Ориентация КА",
        xlim=[x_left, x_right],
        filename="euler_angles",
    )


graphs_plotting()


# %%
if MISSION_PREFIX in {"SixU", "verification"}:
    times = []
    values = []

    for i in range(2):
        r1_interp, r2_interp = interpolate_positions(
            sat_solutions[i][0], sat_solutions[i][1], free_flight_single_sol[i]
        )
        t_0 = sat_solutions[i][0][0][1]  # ignore first (zero index) point
        t_end = free_flight_single_sol[i].t[-1]
        t = np.arange(t_0, t_end, 0.01 * (t_end - t_0))
        dot_product = [
            phasing_angle(
                r1_interp(ti),
                r2_interp(ti),
            )
            for ti in t
        ]
        times.append(t)
        values.append(dot_product)

    t = free_flight_phasing_sol.t
    r1 = free_flight_phasing_sol.y[0:3]
    r2 = free_flight_phasing_sol.y[6:9]
    dot_product = [phasing_angle(r1[:, i], r2[:, i]) for i in range(len(t))]
    times.append(t)
    values.append(dot_product)

    times[1], times[2] = times[2], times[1]
    values[1], values[2] = values[2], values[1]

    my_plot(
        times,
        values,
        ylabels=[
            "перелет на орбиты фазирования",
            "фазирование в свободном полете",
            "возврат на начальную орбиту",
        ],
        ylabel="Скалярное произведение\nединичных радиус-векторов",
        title=(
            "Скалярное произведение единичных\n"
            "радиус-векторов двух космических аппаратов"
        ),
        filename="phasing_angle",
    )

# %%
# 3D visualization

fig_3D = plt.figure()
ax = plt.axes(projection="3d")

# Создание заголовка и подписей осей
ax.set_title("Движение спутников")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("z")

# Загрузка данных для визуализации
# Только первая фаза
sol_3D_1 = sat_solutions[0][0]
sol_3D_2 = sat_solutions[0][1]
x_3D_1, y_3D_1, z_3D_1 = sol_3D_1[1][2:5]
x_3D_2, y_3D_2, z_3D_2 = sol_3D_2[1][2:5]

# Построение графика
ax.plot3D(x_3D_1, y_3D_1, z_3D_1, c="green")
ax.plot3D(x_3D_2, y_3D_2, z_3D_2, c="blue")

# Задание пределов осей
xyz_limits = np.array([ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d()]).T
limits_min_max = np.asarray([min(xyz_limits[0]), max(xyz_limits[1])])
ax.set_xlim3d(limits_min_max)
ax.set_ylim3d(limits_min_max)
ax.set_zlim3d(limits_min_max * 3 / 4)

# Отображение графика
plt.show()
